function [xs,ws,npols] = sop2(param,errbnd)
%
% Given the parameters param of the Laplace-transformed memory function
% $\tilde{M}(u)$ for the TPL model and the desired precision errbnd,
% this function returns the sum-of-poles approximation of
% \tilde{M}(u) - n, or equivalently the sum-of-exponentials
% approximation of M(t) - n\delta(t).
%
% That is, \tilde{M}(u) \approx n + \sum_i ws(i)/(u-xs(i))
% and      M(t) \approx n\delta(t) + \sum_i ws(i) exp(xs(i) t),
% where n is the normalization factor in \tilde{M}(u).
%
% Since \lim_{u\rightarrow \infty} \tilde{M}(u) = n, while a sum of
% poles decays at infinity, n has to be subtracted from \tilde{M}(u)
% before it is approximated.
%
% The function values are obtained from mem(u,param), which is defined
% elsewhere (NOTE(review): presumably it already returns \tilde{M}(u) - n,
% since its values are fitted directly; confirm).
%
% Outputs:
%   xs    - poles (column vector)
%   ws    - weights/residues (column vector)
%   npols - number of poles, length(xs)
%
% Side effects: sets the display format (format short e, format compact)
% and prints progress messages with disp.
%
% Written by: Shidong Jiang
%             Department of Mathematical Sciences
%             New Jersey Institute of Technology
%             Newark, NJ 07102 USA
%
% Last Modified on July 22, 2017
%
% Please send your questions to shidong.jiang@njit.edu
%
%
% The following 8 parameter sets have all passed the test
% with errbnd = 1e-6 or 1e-4.
%
% param1 = [1.5111  -2.0152e+00   1.1077e+01];
% param2 = [1.4393  -2.9943e+00   1.0554e+01];
% param3 = [1.3726  -2.4137e+00   1.1541e+01];
% param4 = [1.5968  -1.6626e+00   1.0396e+01];
% param5 = [1.3605  -9.8386e-01   1.1942e+01];
% param6 = [1.9600   8.1954e-01   5.0000e+00];
% param7 = [1.9600   8.1954e-01   2.0000e+00];
% param8 = [1.6000   8.1954e-01   5.0000e+00];
%
% Usage example: [xs,ws,npols] = sop2(param8,1e-6);
%
format short e
format compact
%
% (a). Find the end point b of the interval [-b,b] along the imaginary
% axis on which the least-squares fit for the initial weights is done:
% starting from 1e8, halve b (at most 30 times) while |Re mem(1i*b)| is
% below 1e-6, then enlarge the result by a factor of 1.6.
%
b = 1e8;
for i=1:30
  f = real(mem(1i*b,param));
  if (abs(f)<1e-6)
    b = b/2;
  else
    break;
  end
end
b = b*1.6;
a = 1e-2;
%
% (b). Find sampling points that sufficiently resolve the function,
% using recursive binary refinement.
%
xs = samppts(@(x) mem(1i*x,param),a,b,1e-9);
% ! WARNING: the current implementation of the MEM function
% suffers from cancellation error near the origin. This is why
% we treat the near interval [1e-6, 1e-2] separately!!!!
x0 = samppts(@(x) mem(1i*x,param),1e-6,a,5e-8);
x = [x0 xs]';
%
% Function values at the sampling points. The data are mirrored onto the
% negative imaginary axis via mem(-1i*x) = conj(mem(1i*x))
% (assumes the underlying M(t) is real-valued).
%
rhs = mem(1i*x,param);
x = [flipud(-x); x];
rhs = [conj(flipud(rhs)); rhs];
%
% (c)-(d). Search over initial pole sets poleinit(M,N,ne) and fit the
% weights by least squares until the relative residual drops below restol.
%
ne = 24;        % poleinit: number of nodes on each subinterval
reps = 5d-13;   % myls: singular values not above this are discarded
restol = 1e-9;  % target relative least-squares residual
M0 = 3; N = 1;
if log10(b)>5, M0 = 4; N = 2; end

converged = false;
resmin = Inf;
xsmin = [];
wsmin = [];
% increase the niter/miter limits to make the algorithm more robust but
% slower, or decrease them to make it less robust but faster.
for niter=1:5
  M = M0;
  for miter=1:6
    xs = poleinit(M,N,ne);
    [ws,res] = fitweights(xs,x,rhs,reps);
    % keep the best fit seen so far in case none reaches restol
    if isempty(xsmin) || res < resmin
      resmin = res;
      xsmin = xs;
      wsmin = ws;
    end
    if res < restol
      converged = true;
      break;
    end
    M = M+1;
  end
  if converged
    break;
  end
  N = N+1;
end

if ~converged
  % fall back to the best candidate instead of recomputing it
  xs = xsmin;
  ws = wsmin;
  res = resmin;
end
disp(['Least squares residual is ',num2str(res)])
%
% (e). Use the square root method to reduce the number of poles.
%
[ws,xs] = squareroot(real(ws),xs,errbnd);
npols = length(xs);
%
% (f). Test the accuracy on the positive imaginary axis.
%
[xs,ws,rerr] = esterror(xs,ws,param);
disp(['Estimated relative L2 error is ', num2str(rerr)])
disp(['number of poles is ',num2str(npols)]);
% semilogx(real(xs),imag(xs),'r.')
% title('Pole locations')
% drawnow
%
%
%
function [ws,res] = fitweights(xs,x,rhs,reps)
% Least-squares weights ws such that sum_j ws(j)/(1i*x - xs(j)) ~ rhs,
% together with the relative residual res returned by myls.
  [X,Y] = meshgrid(xs,1i*x);
  A = 1./(Y-X);
  [ws,res] = myls(A,rhs,reps);
%
%
%
function [xs,ws,rerr] = esterror(xs,ws,param)
%
% Estimate the relative L2 error of the sum-of-poles approximation
% sum_i ws(i)/(u-xs(i)) against mem(u,param). The error is sampled at
% m logarithmically spaced points u = 1i*x, x in [1e-6, 1e8], on the
% positive imaginary axis.
%
% If dropping the imaginary parts of xs and ws increases the error by
% less than a factor of 5, the real parts are returned instead (and
% 'poles and weights are real' is printed).
%
% xs and ws are expected to be column vectors (as produced by mod2rat).
%
  m = 2000;    % number of testing points on the positive imaginary axis
  estart = -6;
  eend = 8;
  x = 10.^linspace(estart,eend,m);
% Mirroring the points onto the negative imaginary axis is skipped to save
% testing time, since evaluating mem is rather expensive:
% x = [fliplr(-x) x];

  ftrue = mem(1i*x,param);
  fcomp = sopeval(xs,ws,x);

  xs1 = real(xs);
  ws1 = real(ws);
  fcomp1 = sopeval(xs1,ws1,x);

%figure; semilogx(x,real(ftrue),'r.',x,imag(ftrue),'b')
%hold on;
%semilogx(x,real(fcomp),'g.',x,imag(fcomp),'y-')

  nrm = norm(ftrue);
  rerr = norm(ftrue-fcomp)/nrm;
  rerr1 = norm(ftrue-fcomp1)/nrm;
  % written without division so that rerr == 0 cannot produce 0/0
  if rerr1 < 5*rerr
    disp('poles and weights are real');
    xs = xs1;
    ws = ws1;
    rerr = rerr1;
  end
%
function f = sopeval(xs,ws,x)
% Evaluate sum_i ws(i)/(1i*x(j) - xs(i)) for every entry of the row
% vector x; xs and ws are column vectors, f is a row vector.
  f = ws.'*(1./bsxfun(@minus,1i*x,xs));
%
%
%
function [xs,npols] = poleinit(M,N,ne)
%
% Initial pole set on the negative real axis.
% The positive mirror image consists of
%   - the ne-2 interior points of linspace(0,4^(-M),ne), followed by
%   - for each i = -M..N, the first ne-1 points of linspace(4^i,4^(i+1),ne)
%     (the right end point is shared with the next subinterval, or
%     dropped for the last one).
% The whole row vector is negated; npols is its length.
%
  first = linspace(0,4^(-M),ne);
  chunks = {first(2:end-1)};
  for k = -M:N
    seg = linspace(4^k,4^(k+1),ne);
    chunks{end+1} = seg(1:end-1); %#ok<AGROW>
  end
  xs = -[chunks{:}];
  npols = numel(xs);
%
% Chebyshev sample functions
%
function xs = samppts(fun,A,B,errbnd)
%
% Return sampling points that resolve fun on [A,B].
% [A,B] is bisected repeatedly; each half is checked with split(), which
% flags it for further bisection unless an n-term Chebyshev expansion
% resolves fun there to precision errbnd. Every final subinterval then
% contributes n*noversample Chebyshev nodes of the first kind (chnodc).
% xs is a row vector, ordered from left to right.
%
  n = 16;
  noversample = 2;
  endpts = [A B];
  ifsplit = 1;     % the whole interval is always bisected at least once
  while any(ifsplit > 0)
    % Visit flagged intervals from right to left: each split inserts one
    % entry into endpts and ifsplit, and processing right-to-left keeps
    % the indices of the intervals still to be visited valid.
    for i = fliplr(find(ifsplit))
      a = endpts(i);
      b = endpts(i+1);
      c = (a+b)/2;
      endpts = [endpts(1:i) c endpts(i+1:end)];
      sl = split(fun,a,c,n,errbnd);
      sr = split(fun,c,b,n,errbnd);
      ifsplit = [ifsplit(1:i-1) sl sr ifsplit(i+1:end)];
    end
  end
  nint = length(endpts)-1;
  m = n*noversample;
  xs = zeros(1,m*nint);
  for i = 1:nint
    xs((i-1)*m+(1:m)) = chnodc(endpts(i),endpts(i+1),m);
  end
%
%
%
function ifsplit = split(fun,a,b,n,errbnd)
%
% Decide whether [a,b] must be bisected further.
% fun is sampled at n Chebyshev nodes and expanded in Chebyshev
% polynomials. The interval counts as resolved (ifsplit = 0) when the
% last two coefficient magnitudes are below errbnd relative to the first
% two, or when the first two are themselves below errbnd. Otherwise
% ifsplit = 1.
%
  c = chexfc(fun(chnodc(a,b,n)),n);
  lead = abs(c(1)) + abs(c(2));
  tail = abs(c(n)) + abs(c(n-1));
  resolved = tail/lead < errbnd || lead < errbnd;
  if resolved
    ifsplit = 0;
  else
    ifsplit = 1;
  end
%      
%
function xs = chnodc(a,b,n)
% Return the n Chebyshev nodes of the first kind mapped from [-1,1]
% to [a,b], as a row vector in increasing order (for a < b).
  k = n:-1:1;
  theta = (2*k-1)*pi/2/n;
  halfwidth = (b-a)/2;
  center = (b+a)/2;
  xs = halfwidth*cos(theta) + center;
%
%
%
function coefs = chexfc(fval,n)
% Given the row vector fval of function values at the n Chebyshev nodes
% of the first kind in increasing order (as returned by chnodc), return
% the coefficients coefs of its Chebyshev expansion,
% f(x) ~ sum_{k=0}^{n-1} coefs(k+1) T_k(x), computed via an FFT of the
% even (mirrored) extension of the data.
  spectrum = fft([fliplr(fval) fval]);
  shift = exp(-1i*pi/2/n*(0:n-1));
  coefs = spectrum(1:n).*shift/n;
  coefs(1) = coefs(1)/2;
%
%
%
function [x,res]=myls(A,b,eps)
%
% Truncated-SVD least-squares solution of A*x ~ b.
% Singular values not exceeding the absolute threshold eps are discarded.
% (The argument name shadows the built-in eps inside this function.)
%
% Returns x, a column vector of length size(A,2), and the relative
% residual res = norm(A*x-b)/norm(b). res is 0 when b is zero, where
% x = 0 is an exact fit.
%
  [U,S,V] = svd(A,0);
  s = diag(S);
  keep = s > eps;   % singular values are sorted, so this is a prefix
  % x = sum_{kept i} (U(:,i)'*b / s(i)) * V(:,i)
  x = V(:,keep)*((U(:,keep)'*b)./s(keep));
  nb = norm(b);
  if nb == 0
    res = 0;
  else
    res = norm(A*x-b)/nb;
  end
%
% 
%
function [wn, pn, bound] = squareroot(w,p,errbnd)
% Reduce the pole-residue model sum_i w(i)/(u-p(i)) by converting it to
% a state-space realization (rat2mod), truncating it with the square-root
% method (squareroot0), and converting back (mod2rat).
% Returns the new weights wn, poles pn, and squareroot0's error bound.
  [Afull,Bfull,Cfull] = rat2mod(w,p);
  [Ared,Bred,Cred,bound] = squareroot0(Afull,Bfull,Cfull,errbnd);
  [wn,pn] = mod2rat(Ared,Bred,Cred);
%    
%
%
function [Ahat,Bhat,Chat,bound] = squareroot0(A,B,C,errbnd)
%
% Square-root balanced truncation of the state-space model (A,B,C).
% The Gramians are obtained in Cholesky-factored form with lyapchol
% (assumes A is stable). The Hankel singular values sigma larger than
% errbnd*sigma(1) are kept; the others are truncated.
%
% Returns the reduced model (Ahat,Bhat,Chat) and
% bound = 2*sum of the discarded Hankel singular values, the standard
% balanced-truncation error bound on the transfer function.
%
  Lr = lyapchol(A,B)';     % controllability Gramian = Lr*Lr'
  Lo = lyapchol(A',C')';   % observability Gramian   = Lo*Lo'
  [U,S,V] = svd(Lo'*Lr);
  sigma = diag(S);         % Hankel singular values, descending
  k = sum(sigma > errbnd*sigma(1));
  % sigma is sorted, so the discarded values are exactly sigma(k+1:end)
  bound = 2*sum(sigma(k+1:end));
  H = diag(sigma(1:k).^(-1/2));
  SL = Lo*U(:,1:k)*H;
  SR = Lr*V(:,1:k)*H;
  Ahat = SL'*A*SR;
  Bhat = SL'*B;
  Chat = C*SR;
%
%
%
function [A,B,C]=rat2mod(w,p)
% Diagonal state-space realization of the pole-residue model
% sum_i w(i)/(u-p(i)) = C*(u*I-A)^(-1)*B, with the residues split
% symmetrically as B(i) = C(i) = sqrt(w(i)).
  B = sqrt(w);
  C = transpose(B);
  A = diag(p);
%
%
%
function [w,p]=mod2rat(A,B,C)
% Convert the state-space model (A,B,C) back to pole-residue form
% sum_i w(i)/(u-p(i)) by diagonalizing A (assumes A is diagonalizable).
% The poles p are the eigenvalues of A; the residues are
% w(i) = (C*V)(i) * (V\B)(i). Both are returned as column vectors.
  [V,D] = eig(A);
  p = diag(D);
  w = transpose(C*V) .* (V\B);
